import java.util.HashMap;

class Solution {
    /**
     * Finds the length of the longest contiguous subarray of {@code nums} that contains as many
     * zeros as non-zeros.
     *
     * <p>Each element equal to 0 contributes -1 to a running balance and every other element
     * contributes +1. If the same balance appears at two positions, the elements between them
     * sum to zero, so they form a balanced subarray. The earliest index of each balance is
     * remembered, and the widest such gap is the answer.
     *
     * @param nums the input array; zeros count as -1 and all other values as +1
     * @return the length of the longest balanced contiguous subarray, or 0 if there is none
     */
    public int findMaxLength(int[] nums) {
        HashMap<Integer, Integer> firstIndexOfBalance = new HashMap<>();
        // Balance 0 is treated as occurring just before the array starts (index -1).
        // This lets a prefix that is itself balanced be measured from the array's start.
        firstIndexOfBalance.put(0, -1);
        int balance = 0;
        int longest = 0;
        int index = 0;
        while (index < nums.length) {
            balance += nums[index] != 0 ? 1 : -1;
            // Store only the first occurrence so that later matches yield the widest span.
            // The stored values are never null, so a null return means this balance is new.
            Integer earlier = firstIndexOfBalance.putIfAbsent(balance, index);
            if (earlier != null) {
                longest = Math.max(longest, index - earlier);
            }
            index++;
        }
        return longest;
    }
}